import json
import zipfile

from hls4ml.model.optimizer.optimizer import ConfigurableOptimizerPass, ModelOptimizerPass


def initialize_large_fifos(model, profiling_fifo_depth):
    """Set all internal FIFO depths equal to a large value so that they can be profiled.

    Only stream variables that connect two layers are touched. Model inputs and outputs are skipped because their
    FIFOs can't be profiled and are implementation dependent (i.e. AXI Stream, AXI Master or connected to another IP).
    Variables that are not streams, or that carry no pragma, are skipped as well.

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.
        profiling_fifo_depth (int): A large positive integer, must be larger than the max expected depth of the FIFOs.

    Returns:
        Dict[str, int]: A dictionary containing FIFO names as keys and their initial depths as values is returned for
        comparison with the optimized depths.
    """
    # Exclude *every* model input and output, not only the first of each: in a multi-input/-output model the extra
    # top-level FIFOs would otherwise be forced to `profiling_fifo_depth`, and since they are not profiled they would
    # never be restored by set_optimized_fifo_depths. The list is computed once instead of once per variable.
    io_variables = [*model.get_input_variables(), *model.get_output_variables()]

    # initialize all the fifos to `profiling_fifo_depth` so that they will be automatically implemented in BRAMs and so
    # they will be profiled. Alternatively, "config_dataflow -override_user_fifo_depth profiling_fifo_depth" can be
    # used inside build_prj.tcl to override all FIFO depths with the specified value
    initial_fifo_depths = {}
    for output_variable in model.output_vars.values():
        # type-name check, so the backend-specific StreamVariable class does not have to be imported here
        if 'StreamVariable' not in str(type(output_variable)):
            continue
        if output_variable in io_variables:
            continue
        if not output_variable.pragma:
            continue
        # pragma is a (kind, depth) pair; remember the original depth, then replace it with the profiling depth
        initial_fifo_depths[output_variable.name] = int(output_variable.pragma[1])
        output_variable.pragma = (output_variable.pragma[0], profiling_fifo_depth)
    return initial_fifo_depths


def execute_cosim_to_profile_fifos(model):
    """Write the project and run synthesis plus co-simulation so that the FIFO depths get profiled.

    The test-bench calls the top function, which must execute at least twice, so the user-provided input must
    contain at least two samples.

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.
    """
    model.write()

    # Enable only synthesis and co-simulation (with FIFO optimization); every other build step is switched off.
    build_options = dict(
        reset=False,
        csim=False,
        synth=True,
        cosim=True,
        validation=False,
        export=False,
        vsynth=False,
        fifo_opt=True,
    )
    model.build(**build_options)


def get_vitis_optimized_fifo_depths(model):
    """Parse the files generated by the co-simulation to retrieve the optimized depths for the FIFOs.
    Attention, only the FIFOs between the layers are profiled!

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.

    Returns:
        Dict[str, int]: A dictionary that contains the FIFO names as keys and the optimized depths as values.
    """
    db_dir = f'{model.config.get_output_dir()}/{model.config.get_project_name()}_prj/solution1/.autopilot/db/'

    # channel.zip is generated after the co-simulation and contains the chan_status*.csv files
    # in the chan_status*.csv files the max depth achieved during co-simulation can be found at the last (4th) line
    channel_depth_dir = f'{db_dir}channel_depth_info/'
    with zipfile.ZipFile(f'{channel_depth_dir}channel.zip', 'r') as zip_ref:
        zip_ref.extractall(channel_depth_dir)

    # the channel_info.csv file contains the mapping of each fifo name (i.e layer4_out_U) to the respective
    # chan_status*.csv file: column 1 holds the FIFO name, column 3 the chan_status file name
    csv_fifo_depth_files = {}
    with open(f'{db_dir}channel_info.csv') as names_file:
        for raw_line in names_file:
            # strip() instead of chopping the last character: a final line without a trailing newline keeps its
            # full file name, and '\r\n' line endings are handled too
            line = raw_line.strip()
            if not line:
                continue
            fields = line.split(',')
            csv_fifo_depth_files[fields[1]] = fields[3]

    optimized_fifo_depths = {}
    for fifo_name, file_name in csv_fifo_depth_files.items():
        with open(channel_depth_dir + file_name) as chan_status_file:
            # ignore blank lines so a trailing empty line does not hide the max depth
            content_lines = [row for row in chan_status_file if row.strip()]
        # remove the "_U" suffix (only when present) to get back the hls4ml variable name
        if fifo_name.endswith('_U'):
            fifo_name = fifo_name[:-2]
        # the last line of the file contains the max depth
        optimized_fifo_depths[fifo_name] = int(content_lines[-1])

    return optimized_fifo_depths


def generate_depths_file(model, initial_fifo_depths, optimized_fifo_depths):
    """Generate a json file with the names of the FIFOs, the initial depths set by hls4ml and their optimized depths,
    for post-processing. The json file is not used by the rest of the pipeline, it is only produced for the user.

    The file is written to ``<output_dir>/fifo_depths.json`` and maps each FIFO name to
    ``{"initial": <int>, "optimized": <int or null>}``.

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.
        initial_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the initial
        depths as values.
        optimized_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the optimized
        depths as values.
    """
    depths = {
        fifo_name: {
            'initial': initial_depth,
            # a FIFO absent from the co-simulation report is recorded as null instead of aborting the whole pass
            # with a KeyError; set_optimized_fifo_depths likewise leaves such FIFOs untouched
            'optimized': optimized_fifo_depths.get(fifo_name),
        }
        for fifo_name, initial_depth in initial_fifo_depths.items()
    }

    with open(f'{model.config.get_output_dir()}/fifo_depths.json', 'w') as f:
        json.dump(depths, f, indent=4)


def set_optimized_fifo_depths(model, optimized_fifo_depths):
    """Write the optimized depths back into the FIFO pragmas of the model.

    A variable is updated only if it is a stream variable, carries a pragma, and its name is a key of
    `optimized_fifo_depths`; every other variable keeps its current pragma.

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.
        optimized_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the optimized
        depths as values.
    """
    for variable in model.output_vars.values():
        is_stream = 'StreamVariable' in str(type(variable))
        # guard clauses: skip non-streams, pragma-less streams and FIFOs that were not profiled
        if not (is_stream and variable.pragma):
            continue
        if variable.name not in optimized_fifo_depths:
            continue

        pragma_kind = variable.pragma[0]
        variable.pragma = (pragma_kind, optimized_fifo_depths[variable.name])


class FifoDepthOptimization(ConfigurableOptimizerPass, ModelOptimizerPass):
    """Optimizer pass that resizes the inter-layer FIFOs of an ``io_stream`` model using Vitis co-simulation."""

    def __init__(self):
        # Depth given to every internal FIFO while profiling. It must be a positive integer (0 is rejected by
        # `transform`) and larger than any depth the FIFOs can reach during co-simulation.
        # Consider changing 100_000 either to a very large value (> any total BRAM storage space) or to a value
        # obtained via Vitis 2023.2 C-simulation.
        self.profiling_fifo_depth = 100_000

    def transform(self, model):
        """Perform FIFO depth optimization between the FIFOs of all layers to reduce resource utilization as the
        initial FIFOs set by hls4ml might be larger than required. At the end of the optimization the FIFOs will
        have the largest depths achieved during co-simulation without causing any deadlocks between the layers
        (producer-consumer), thus no additional delays between the layers. In some cases, this optimization
        might lead to bigger FIFOs than initially set by the hls4ml tool in order to prevent deadlocks.

        Args:
            model (ModelGraph): The model to which FIFO depth optimization is applied.

        Raises:
            ValueError: If the FIFO depth for profiling provided by the user is not a positive integer.
            RuntimeError: If the IO type is not set to "io_stream".

        Returns:
            bool: The execution state of the Optimizer Pass
        """
        depth = self.profiling_fifo_depth
        # bool is a subclass of int, so it must be rejected explicitly; the check requires depth > 0
        if isinstance(depth, bool) or not isinstance(depth, int) or depth <= 0:
            raise ValueError('The FIFO depth for profiling (profiling_fifo_depth variable) must be a positive integer.')

        # check axi-stream or io-stream
        if model.config.get_config_value('IOType') != 'io_stream':
            raise RuntimeError('To use this optimization you have to set `IOType` field to `io_stream` in the HLS config.')

        initial_fifo_depths = initialize_large_fifos(model, depth)
        execute_cosim_to_profile_fifos(model)
        optimized_fifo_depths = get_vitis_optimized_fifo_depths(model)
        generate_depths_file(model, initial_fifo_depths, optimized_fifo_depths)
        set_optimized_fifo_depths(model, optimized_fifo_depths)

        print('FIFO optimization completed')

        # NOTE(review): presumably False tells the optimizer flow that the graph structure was not changed
        # (only FIFO pragmas were) — confirm against ModelOptimizerPass semantics
        return False
